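"""Get the state transitions.

Note: this relies on an intermediate STATA step, src/models_new/state_transitions_stata.do.
Its Excel output (presumably the files under ./models/first_stage/transitions_stata/)
is read here; this script does not invoke STATA itself, so run the .do file first.

"""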

import os
import itertools
import pandas as pd
import numpy as np
import time


#%% DEFINE A FUNCTION -------------------------------------------------------------------
def build_table_transitions(df_all, *, decimals=2, tex_template=None, tex_output=None):
    """Collect regression estimates and assemble the state-transition system.

    Parameters
    ----------
    df_all : dict
        Maps each equation name ('low', 'mid', 'high', 'price') to a table of
        estimates such that ``df_all[row][col]['b']`` is the point estimate and
        ``df_all[row][col]['se']`` its standard error for term ``col``
        ('price', 'low', 'mid', 'high', 'constant', 'sigma').
        Combinations absent from the table are skipped.
    decimals : int or None, keyword-only, default 2
        Number of decimals each estimate and standard error is rounded to.
        ``None`` keeps full precision.
    tex_template, tex_output : str or path-like, keyword-only, optional
        When both are given, ``tex_template`` is read, filled via
        ``str.format`` with every ``param_*`` and ``error_*`` entry, and the
        result is written to ``tex_output``. They must be given together.

    Returns
    -------
    r : numpy.ndarray, shape (4, 4)
        Coefficient matrix with rows and columns ordered (price, low, mid,
        high). In the price row, the low/mid/high entries are fixed at zero.
    const : numpy.ndarray, shape (4, 1)
        Constant terms in the same row order.
    sigma : float
        The price equation's ``sigma`` estimate.
    params : dict
        ``'param_<row>_<col>'`` -> (rounded) point estimate.
    errors : dict
        ``'error_<row>_<col>'`` -> (rounded) standard error.

    Raises
    ------
    KeyError
        If an estimate required for ``r``, ``const`` or ``sigma`` is missing.
    ValueError
        If only one of ``tex_template`` / ``tex_output`` is given.
    """
    if (tex_template is None) != (tex_output is None):
        raise ValueError('tex_template and tex_output must be given together')

    equations = ('low', 'mid', 'high', 'price')
    terms = ('price', 'low', 'mid', 'high', 'constant', 'sigma')

    def _round(value):
        return value if decimals is None else round(value, decimals)

    params = {}
    errors = {}
    for row, col in itertools.product(equations, terms):
        try:
            params[f'param_{row}_{col}'] = _round(df_all[row][col]['b'])
            errors[f'error_{row}_{col}'] = _round(df_all[row][col]['se'])
        except KeyError:
            # Not every equation has every term (e.g. the price equation has no
            # low/mid/high columns); absent combinations are simply skipped.
            # Any other error (e.g. a non-numeric cell) is allowed to surface.
            pass

    if tex_template is not None:
        # Optional LaTeX table export (previously a disabled, hard-coded block).
        with open(tex_template, 'r', encoding='utf-8') as f:
            tex = f.read()
        with open(tex_output, 'w', encoding='utf-8') as f:
            f.write(tex.format(**params, **errors))

    # Row/column order of the transition system.
    state_order = ('price', 'low', 'mid', 'high')
    # The price row only uses its own price coefficient; the other
    # entries in that row are zero.
    r = np.array(
        [[params['param_price_price'], 0, 0, 0]]
        + [
            [params[f'param_{row}_{col}'] for col in state_order]
            for row in state_order[1:]
        ]
    )
    const = np.array([[params[f'param_{row}_constant']] for row in state_order])
    sigma = params['param_price_sigma']

    return r, const, sigma, params, errors

#%% GET THE COEFFS IN A NICE FORM -------------------------------------------------------
# For each time aggregation: load the estimates of the four equations from the
# Excel files, assemble the transition system and save each piece as CSV.
df_all_by_time = dict()
for t in ('fortnight', 'month'):
    # The low/mid/high files are read with the full set of regressor columns.
    df_all_by_time[t] = {
        spec: pd.read_excel(
            f'./models/first_stage/transitions_stata/transitions_{spec}_{t}.xlsx',
            index_col=0,
            names=['index', 'price', 'low', 'mid', 'high', 'constant', 'sigma'],
        )
        for spec in ('low', 'mid', 'high')
    }
    # The price file is read with only price, constant and sigma columns.
    df_all_by_time[t]['price'] = pd.read_excel(
        f'./models/first_stage/transitions_stata/transitions_price_{t}.xlsx',
        index_col=0,
        names=['index', 'price', 'constant', 'sigma'],
    )
    r, const, sigma, params, errors = build_table_transitions(df_all_by_time[t])
    for path, table in (
        (f'./models/first_stage/r_{t}.csv', pd.DataFrame(r)),
        (f'./models/first_stage/const_{t}.csv', pd.DataFrame(const)),
        (f'./models/first_stage/sigma_{t}.csv', pd.DataFrame([sigma])),
        (f'./models/first_stage/params_{t}.csv', pd.Series(params)),
        (f'./models/first_stage/errors_{t}.csv', pd.Series(errors)),
    ):
        table.to_csv(path)
